import os
import math
import numpy
import time
from KG_parameter import KG
import codecs
import re
import argparse
from util import load_pickle, create_co_matrix, threshold_cooccur, sppmi, most_similar
import Parameter
import svm_emb

from scipy import sparse

import sys

# Configure numpy printing in ONE call.  numpy resets ``formatter`` on every
# set_printoptions() call, so setting the formatter first and the threshold in
# a second call would silently drop the formatter again.
#   - suppress=True / 20 fixed decimals: never fall back to scientific notation
#   - threshold=sys.maxsize: never abbreviate long arrays with "..."
numpy.set_printoptions(
    suppress=True,
    threshold=sys.maxsize,
    formatter={'float_kind': '{:.20f}'.format},
)


def normalize1(X, dim):
    """Return ``X`` scaled so that its entries sum to 1 along axis ``dim``.

    Args:
        X: numeric array (the callers in this file pass 2-D ndarrays).
        dim: axis along which the sums are taken; ``dim=1`` makes every row
            of a 2-D array sum to 1, ``dim=0`` every column.

    Returns:
        A new floating-point array with the same shape as ``X``.

    Note:
        A slice whose sum is zero produces inf/nan entries (numpy division
        semantics); no special handling is applied.
    """
    # keepdims keeps the reduced axis with length 1 so the division
    # broadcasts correctly for any axis, not only for dim=1.
    totals = numpy.sum(X, axis=dim, keepdims=True)
    return X / totals


def findMax(lil_m):
    """Return the largest explicitly stored value of a sparse matrix.

    Only stored entries are inspected (implicit zeros are ignored).  Raises
    ValueError when the matrix stores no entries at all.
    """
    stored_values = lil_m.tocoo().data
    return stored_values.max()


def readDict(dict_path):
    """Load a vocabulary file with one word per line.

    Each line is stripped of surrounding whitespace and assigned the id equal
    to its zero-based line number.

    Args:
        dict_path: path of a UTF-8 encoded text file.

    Returns:
        ``(id2word, word2id)``: a dict mapping id -> word and a dict mapping
        word -> id.  If a word occurs on several lines, ``word2id`` keeps the
        id of its last occurrence while ``id2word`` keeps every line.
    """
    # The context manager closes the file even if decoding fails midway.
    with codecs.open(dict_path, 'r', 'utf-8') as handle:
        words = [line.strip() for line in handle]

    id2word = dict(enumerate(words))
    word2id = {word: idx for idx, word in enumerate(words)}

    return id2word, word2id


def topT(U_t, T):
    """Return, for every column of ``U_t``, the row indices of its T largest entries.

    The result is a list with one entry per column; each entry is a list of
    at most ``T`` indices ordered from the largest value downwards.
    """
    per_topic = U_t.transpose()
    # Negating the scores makes the ascending argsort yield a descending order.
    return [numpy.argsort(-scores).tolist()[:T] for scores in per_topic]


def _write_embedding_file(dict_path, out_path, Bt, num_word, num_embedding):
    """Write the rows of ``Bt`` to ``out_path`` as a text embedding file.

    The first line is "<num_word> <num_embedding>".  Every following line is
    a word read from ``dict_path`` (one word per line, in file order) followed
    by the textual form of the matching row of ``Bt``.
    """
    with open(dict_path, "r", encoding='utf-8') as f:
        words = [line.strip("\n") for line in f]

    with open(out_path, "w", encoding='utf-8') as out:
        out.write(str(num_word) + " " + str(num_embedding) + "\n")
        for i, word in enumerate(words):
            # str(row) looks like "[v0 v1 ...]" and may be wrapped over
            # several lines; drop the newlines and the enclosing brackets.
            # The number formatting follows the global numpy print options.
            row_text = str(Bt[i]).replace('\n', '')[1:].rstrip("]")
            out.write(word + " " + row_text + "\n")


def CLLM(domain, num_docs_chunk, KGraph, total_id2word, total_word2id, option, Dt_total, Xt_total, St_total):
    """Fit the factor matrices domain by domain and yield the results of each.

    This is a generator.  For every ``t`` in ``range(num_docs_chunk)`` it
    loads the vocabulary of ``domain[t]``, projects the shared knowledge
    graph onto that vocabulary, runs ``option.iter_num`` rounds of
    multiplicative updates on four factor matrices (``Ut``, ``Vt``, ``Ct``,
    ``Bt``), calls ``KG(...)`` with the fitted ``Ut``/``Ct``, writes the
    embedding file(s), evaluates them with ``svm_emb.perform_svm`` and then
    yields.

    Args:
        domain: indexable collection of domain identifiers; ``domain[t]`` is
            used as key into ``Dt_total``/``Xt_total``/``St_total`` and in
            file names.
        num_docs_chunk: number of domains to process.
        KGraph: sparse knowledge-graph matrix indexed by the ids of the total
            vocabulary; it is handed to ``KG`` after every domain.
        total_id2word: mapping id -> word for the total vocabulary.
        total_word2id: mapping word -> id for the total vocabulary (only
            forwarded to ``KG`` here).
        option: object providing ``lambda_u1``, ``lambda_c``, ``lambda_b1``,
            ``lambda_b2``, ``eta``, ``topic_num``, ``emb_size``, ``topT``,
            ``lambda_v``, ``lambda_u2``, ``R_k``, ``iter_num``, ``datadir``
            and ``savedir``.
        Dt_total: per-domain matrices; the transpose of each is factorised as
            ``Ut.dot(Vt)`` (presumably documents x words -- confirm with the
            caller).
        Xt_total: per-domain word x word matrices factorised as
            ``Bt.dot(Ct.T)``.
        St_total: per-domain word x word matrices used to regularise ``Bt``.
            Assumed to be scipy sparse matrices or numpy matrices, i.e.
            ``sum(axis=1)`` returns one value per row -- TODO confirm.

    Yields:
        ``(t, topics, Dt_origin, Ut.T, Vt.T, accuracy_svm, proportion)`` where
        ``topics`` is ``topT(Ut, 20)`` and ``proportion`` is the fraction of
        entries of ``Bt`` smaller than 1e-20.
    """
    # hyper-parameters settings
    parameterNow = [option.lambda_u1, option.lambda_c, option.lambda_b1, option.lambda_b2, option.eta]
    lambda_u1, lambda_c, lambda_b1, lambda_b2, eta = parameterNow

    num_topic = option.topic_num
    num_embedding = option.emb_size
    topnum = option.topT
    eps = 0.0001  # added to numerators/denominators to avoid division by zero
    lambda_v = option.lambda_v
    lambda_u2 = option.lambda_u2
    theta = 1

    R_k = option.R_k  # threshold ratio for KG
    iterations = option.iter_num

    for t in range(num_docs_chunk):
        domain_start_time = time.time()

        # load vocab
        dict_path_t = os.path.join(option.datadir, "Dict", "Id2word_" + str(domain[t]) + ".txt")
        current_id2word, current_word2id = readDict(dict_path_t)

        print('current domain =', domain[t])
        Dt_origin = Dt_total[domain[t]]
        Dt = Dt_origin.transpose()
        print("Dt:ok")

        Xt = Xt_total[domain[t]]
        print("Xt:ok")

        num_document = Dt.shape[1]
        num_word = Dt.shape[0]

        St = St_total[domain[t]]

        # All-ones matrices used by the penalty terms of the Vt / Bt updates.
        E1E2 = numpy.ones((num_topic, num_document), dtype='int32')
        E1E22 = numpy.ones((num_word, num_embedding), dtype='int32')

        # Random non-negative initialisation (Ut and Vt rows sum to 1).
        Ut = normalize1(numpy.random.rand(num_word, num_topic), 1)
        Vt = normalize1(numpy.random.rand(num_topic, num_document), 1)

        Bt = numpy.random.rand(num_word, num_embedding)
        Ct = numpy.random.rand(num_word, num_embedding)

        # W starts as the identity; strong KG edges are added below.
        W = sparse.eye(num_word, num_word, dtype='float', format='lil')

        # Project the shared KG onto the vocabulary of the current domain.
        current_KGraph = sparse.lil_matrix((num_word, num_word), dtype='float')
        kg_rows, kg_cols = KGraph.nonzero()
        for r_ind, c_ind in zip(kg_rows, kg_cols):
            wordx = total_id2word[r_ind]
            wordy = total_id2word[c_ind]
            if wordx in current_word2id and wordy in current_word2id:
                index1 = current_word2id[wordx]
                index2 = current_word2id[wordy]
                current_KGraph[index1, index2] = KGraph[r_ind, c_ind]

        countforW = 0
        if current_KGraph.nnz > 0:
            maxweight = findMax(current_KGraph)
            print('max weight in current KG =', maxweight)
            threshold2 = maxweight * R_k

            # Keep only off-diagonal edges above the threshold, scaled by
            # the maximum weight.
            cur_rows, cur_cols = current_KGraph.nonzero()
            for r_ind, c_ind in zip(cur_rows, cur_cols):
                if r_ind == c_ind:
                    continue
                weight = current_KGraph[r_ind, c_ind]
                if weight > threshold2:
                    countforW += 1
                    W[r_ind, c_ind] = weight / maxweight

        print("W size: ", countforW)

        # Row sums of W and St, i.e. the diagonals of diag(W1) and diag(St1).
        # They are kept as (num_word, 1) float32 columns: multiplying a
        # factor by such a column scales its rows exactly like a product
        # with the dense diagonal matrix, without allocating
        # num_word x num_word arrays.
        degW = numpy.asarray(W.sum(axis=1), dtype='float32').reshape(-1, 1)
        degS = numpy.asarray(St.sum(axis=1), dtype='float32').reshape(-1, 1)

        print('process W&S cost = %d' % (time.time()-domain_start_time))
        update_start_time = time.time()

        for times in range(iterations):
            if times > 0 and times % 50 == 0:
                print('update [%d] cost = %d' % (times, time.time()-update_start_time))

            # Update U_t
            numerator = Dt.dot(Vt.transpose()) + lambda_u1 * W.dot(Ut) + 2 * lambda_u2 * Ut
            denominator = Ut.dot(Vt.dot(Vt.transpose())) + lambda_u1 * (degW * Ut) + 2 * lambda_u2 * Ut.dot(Ut.transpose().dot(Ut))
            Ut = Ut * ((eps + numerator) / (eps + denominator))

            # Update V_t
            numerator = Ut.transpose().dot(Dt)
            denominator = Ut.transpose().dot(Ut.dot(Vt)) + lambda_v / 2 * E1E2
            Vt = Vt * ((eps + numerator) / (eps + denominator))

            # Update C_t
            numerator = theta * Xt.transpose().dot(Bt) + lambda_c * W.dot(Ct)
            denominator = theta * Ct.dot(Bt.transpose().dot(Bt)) + lambda_c * (degW * Ct)
            Ct = Ct * ((eps + numerator) / (eps + denominator))

            # Update B_t
            numerator = theta * Xt.dot(Ct) + lambda_b1 * St.dot(Bt)
            denominator = theta * Bt.dot(Ct.transpose().dot(Ct)) + lambda_b1 * (degS * Bt) + lambda_b2 / 2 * E1E22
            Bt = Bt * ((eps + numerator) / (eps + denominator))

            # Disabled diagnostic (formerly a dead string literal here):
            #   J1 = ||Dt - Ut Vt||,  J2 = ||Xt - Bt Ct^T||,  J3 = ||Vt||_1
            #   J4 = tr(Ut^T (diag(W1) - W) Ut)
            #   J5 = tr((Ut^T Ut)^2 - 2 Ut^T Ut)
            #   J6 = tr(Ct^T (diag(W1) - W) Ct)
            #   J7 = tr(Bt^T (diag(St1) - St) Bt)
            #   cost = J1 + theta*J2 + lambda_v*J3 + lambda_u1*J4
            #          + lambda_u2*J5 + lambda_c*J6 + lambda_b1*J7

        print('update cost = %d' % (time.time()-update_start_time))
        KG_start_time = time.time()

        # update KG
        KG(KGraph, Ut, Ct, topnum, total_id2word, total_word2id, current_id2word, current_word2id, eta)

        print('update KG cost = %d' % (time.time()-KG_start_time))
        save_start_time = time.time()

        topics = topT(Ut, 20)
        if domain[t] == 6 or domain[t] == 7:
            # save KG
            sparse.save_npz(os.path.join(option.savedir, "KG_" + str(parameterNow) + "_" + str(domain[t]) + ".npz"), KGraph.tocoo())

        accuracy_svm = 0
        proportion = 0

        if domain[t] >= 0:
            # This file is overwritten for every domain; it only feeds the
            # SVM evaluation below.
            emb_path = os.path.join(option.savedir, "embedding.vec")
            _write_embedding_file(dict_path_t, emb_path, Bt, num_word, num_embedding)

            # Fraction of (near-)zero entries in the embedding matrix.
            count = int(numpy.count_nonzero(Bt < 0.00000000000000000001))
            proportion = count / Bt.size
            accuracy_svm = svm_emb.perform_svm(domain_name=str(domain[t]), data_dir=option.datadir, word_emb_path=emb_path, word_emb_size=option.emb_size, use_pretrained=True)

        if domain[t] == 7 or domain[t] == 8:
            # Keep a per-domain copy of the embeddings for these domains.
            keep_path = os.path.join(option.savedir, "embedding_" + str(domain[t]) + "_" + ".vec")
            _write_embedding_file(dict_path_t, keep_path, Bt, num_word, num_embedding)

        print('save cost = %d' % (time.time()-save_start_time))
        print('domain[%d] total cost = %d' % (t, time.time()-domain_start_time))

        yield t, topics, Dt_origin, Ut.transpose(), Vt.transpose(), accuracy_svm, proportion
